/* Copyright 2012 Tobias Marschall
 *
 * This file is part of CLEVER.
 *
 * CLEVER is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * CLEVER is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with CLEVER.  If not, see <http://www.gnu.org/licenses/>.
 */

#ifndef NFAMATCHER_H_
#define NFAMATCHER_H_

#include <string>
#include <vector>
#include <memory>
#include <cassert>

/** Bit-parallel implementation of an NFA based matcher. All matches with up to
 *  a given edit distance are found. Implementation is specific to upper case
 *  DNA alphabet {A,C,G,T}; any other character matches nothing. Bit 0 of a
 *  machine word is reserved for the start state, so patterns may be at most
 *  one character shorter than the number of bits in bitmask_t.
 */
class NfaMatcher {
private:
	typedef unsigned int bitmask_t;
	/** One mask per nucleotide (see index()): bit i+1 is set iff the i-th
	 *  pattern character, counted in reading direction, is that nucleotide. */
	std::vector<bitmask_t> bitmasks;
	/** Number of pattern characters encoded in the masks. */
	size_t length;
	/** If true, pattern and reference are read from right to left. */
	bool backwards;

	/** Maps an upper case nucleotide to its slot in bitmasks; returns -1 for
	 *  every other character. */
	static int index(char c) {
		switch(c) {
		case 'A': return 0;
		case 'C': return 1;
		case 'G': return 2;
		case 'T': return 3;
		default: return -1;
		}
	}
public:
	/** Builds the matcher for a substring of the given pattern.
	 *  @param pattern   Sequence the pattern is taken from.
	 *  @param length    Number of pattern characters to use; length+1 must not
	 *                   exceed the number of bits in bitmask_t.
	 *  @param start     Index of the first pattern character that is read.
	 *  @param backwards If false, characters start, start+1, ... are used; if
	 *                   true, characters start, start-1, ... are used.
	 */
	template <typename StringType>
	NfaMatcher(const StringType& pattern, size_t length, size_t start=0, bool backwards=false) : bitmasks(4,0), length(length), backwards(backwards) {
		assert(length+1 <= sizeof(bitmask_t)*8);
		if (backwards) {
			// The pattern must not run past either end of the sequence.
			assert(start+1 >= length);
			assert((length == 0) || (start < pattern.size()));
		} else {
			assert(start+length <= pattern.size());
		}
		for (size_t i=0; i<length; ++i) {
			const int j = index(pattern[backwards ? start-i : start+i]);
			if (j < 0) continue;
			// Shift an unsigned one: the top bit may be needed for the longest patterns.
			bitmasks[j] |= bitmask_t(1) << (i+1);
		}
	}

	virtual ~NfaMatcher() {}

	/** Scans the reference in the direction chosen at construction and reports
	 *  where an occurrence of the pattern with at most max_errors errors ends.
	 *  @param reference      Sequence to be searched.
	 *  @param max_errors     Maximum number of errors; must be smaller than the
	 *                        pattern length.
	 *  @param start          Index of the first reference character that is read.
	 *                        An index outside the reference yields an empty result.
	 *  @param max_length     Maximum number of reference characters to read;
	 *                        a negative value means no limit.
	 *  @param only_first_hit If true, scanning stops at the first position after
	 *                        a hit that does not lower the error count of the
	 *                        hit recorded so far, and only that recorded hit
	 *                        is returned.
	 *  @return Pairs (index of the last reference character read for the hit,
	 *          smallest number of errors of a hit ending there), in the order
	 *          in which they were encountered.
	 */
	template <typename StringType>
	std::auto_ptr<std::vector<std::pair<size_t,size_t> > > findMatches(const StringType& reference, size_t max_errors, size_t start=0, int max_length=-1, bool only_first_hit=false) const {
		typedef std::vector<std::pair<size_t,size_t> > result_t;
		assert(max_errors < length);
		std::auto_ptr<result_t> result(new result_t());
		const bitmask_t one = 1;
		const bitmask_t match_mask = one << length;
		// Initialization: with k errors, the first k pattern characters may be
		// skipped before any reference character has been read, hence bits
		// 0..k are active in state[k].
		std::vector<bitmask_t> state(max_errors+1, one);
		for (size_t k=1; k<=max_errors; ++k) {
			state[k] = (state[k-1]<<1) | one;
		}
		// Determine the number of characters to be read on reference
		const size_t ref_size = reference.size();
		size_t ref_len = 0;
		if (start < ref_size) {
			ref_len = backwards ? start+1 : ref_size-start;
		}
		if ((max_length>=0) && (ref_len > static_cast<size_t>(max_length))) {
			ref_len = static_cast<size_t>(max_length);
		}
		bool have_hit = false;
		size_t last_hit_pos = 0;
		size_t last_hit_errors = 0;
		// Iterate over all characters in the reference sequence
		for (size_t i=0; i<ref_len; ++i) {
			const size_t pos = backwards ? start-i : start+i;
			const int j = index(reference[pos]);
			const bitmask_t mask = (j<0) ? 0 : bitmasks[j];
			// Go from most to fewest errors so that state[k-1] still holds the
			// value from before this reference character was read.
			for (size_t k=max_errors+1; k-- > 0; ) {
				// Matching character; bit 0 stays set so a hit may begin anywhere.
				state[k] = ((state[k]<<1) & mask) | one;
				if (k>0) {
					// Reference character that has no counterpart in the pattern.
					state[k] |= state[k-1];
					// Mismatching character.
					state[k] |= state[k-1]<<1;
				}
			}
			// Pattern character that has no counterpart in the reference
			// (uses the freshly updated states).
			for (size_t k=1; k<=max_errors; ++k) {
				state[k] |= state[k-1]<<1;
			}
			// Test for match, reporting the smallest error count only
			bool new_match = false;
			for (size_t k=0; k<=max_errors; ++k) {
				if ((state[k] & match_mask) == 0) continue;
				if (only_first_hit) {
					if (!have_hit || (last_hit_errors > k)) {
						have_hit = true;
						last_hit_pos = pos;
						last_hit_errors = k;
						new_match = true;
					}
				} else {
					result->push_back(std::make_pair(pos, k));
				}
				break;
			}
			if (only_first_hit && have_hit && !new_match) break;
		}
		if (only_first_hit && have_hit) {
			result->push_back(std::make_pair(last_hit_pos, last_hit_errors));
		}
		return result;
	}
};

#endif /* NFAMATCHER_H_ */
